Solving ODEs in Julia with DifferentialEquations.jl¶

Why use DifferentialEquations.jl?¶

  • Fast
  • Easy to use
  • Lots of features!
  • Has a Python wrapper, diffeqpy

Ordinary Differential Equations¶

• Equations of the form: $$ \frac{dx_i(t)}{dt} = f_i(\vec{x}(t), \theta) $$

• Simple ODEs are easy to solve analytically

Example: Exponential growth¶

$$ \frac{dx(t)}{dt} = rx \qquad \longrightarrow \qquad x(t) = x(0)e^{rt} $$
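As a minimal sketch, the analytic solution above can be checked against a numerical one from DifferentialEquations.jl. The values of `r`, `x0`, and the time span are illustrative assumptions, not taken from the original notebook, and `growth` is a hypothetical helper name.

```julia
using DifferentialEquations

# Right-hand side dx/dt = r*x; the parameter p holds the growth rate r
growth(x, p, t) = p * x

r  = 0.5             # assumed growth rate
x0 = 1.0             # assumed initial value x(0)
tspan = (0.0, 10.0)  # assumed time span

prob = ODEProblem(growth, x0, tspan, r)
sol  = solve(prob, Tsit5(), reltol = 1e-8, abstol = 1e-8)

# Compare the numerical solution with x(0)*exp(r*t) at the final time
numeric  = sol(tspan[2])
analytic = x0 * exp(r * tspan[2])
println("numeric = ", numeric, ", analytic = ", analytic)
```

Calling `sol(t)` evaluates the solver's interpolant at an arbitrary time, so the comparison does not depend on where the adaptive steps happened to land.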
In [354]:
plot(sol, legend = false, xlab = "time", ylab = "x(t)", lw = 2, size = (500,400))
Out[354]:

Simple ODE system¶

Generalised Lotka-Volterra¶

  • Simple model of ecosystem dynamics
  • Must be solved with numerical simulations
$$ \frac{dx_i}{dt} = x_i \left(r_i + \sum_{j=1}^N a_{ij} x_j \right) $$
Parameters

| Symbol | Meaning |
| --- | --- |
| $x_i$ | Biomass of $i^\text{th}$ species |
| $N$ | Number of species |
| $r_i$ | Growth rate of $i^\text{th}$ species |
| $a_{ij}$ | Interaction between the $i^\text{th}$ and $j^\text{th}$ species |
In [355]:
#load packages
using Pkg
Pkg.activate(".")

using DifferentialEquations
using Plots
using Random
In [356]:
#parameters
struct param
    N::Int
    r::Vector{Float64}
    a::Array{Float64,2}
end

#derivative function
function dx!(dx,x,p,t)
    for i = 1:p.N
        dx[i] = x[i] * p.r[i]
        for j = 1:p.N
            dx[i] += x[i] * x[j] * p.a[i,j]
        end
    end
end
Out[356]:
dx! (generic function with 1 method)
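The double loop in `dx!` computes $x_i (r_i + \sum_j a_{ij} x_j)$ for every species, with the diagonal entry $a_{ii}$ acting as self-regulation. The same right-hand side can be written as a matrix expression. The sketch below compares the two forms on small hand-picked values; `glv_loop!` and `glv_matrix!` are hypothetical names, and they take `r` and `a` directly instead of a `param` struct so the block runs on its own.

```julia
# Loop form, mirroring dx! but taking r and a directly (hypothetical helper)
function glv_loop!(dx, x, r, a)
    n = length(x)
    for i in 1:n
        dx[i] = x[i] * r[i]
        for j in 1:n
            dx[i] += x[i] * x[j] * a[i, j]
        end
    end
    return dx
end

# Matrix form: dx = x .* (r + a*x)
function glv_matrix!(dx, x, r, a)
    dx .= x .* (r .+ a * x)
    return dx
end

r = [1.0, 0.5, 0.2]
a = [-2.0 -0.1 -0.3;
     -0.4 -2.0 -0.2;
     -0.5 -0.6 -2.0]
x = [0.1, 0.2, 0.3]

dx_loop   = glv_loop!(zeros(3), x, r, a)
dx_matrix = glv_matrix!(zeros(3), x, r, a)
println(dx_loop ≈ dx_matrix)
```

The matrix form is shorter, but `a * x` allocates a temporary vector on every call, so the explicit loop may still be the better choice inside a solver.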
In [357]:
Random.seed!(2)

#generate params
N = 10
r = rand(N)
a = -rand(N,N)
[a[i,i] = -2 for i = 1:N]
  
p = param(N,r,a)

#define ODE problem
tspan = (0.0,200.0)
x0 = rand(N) ./ 10
prob = ODEProblem(dx!,x0,tspan,p)
sol = solve(prob)

fieldnames(typeof(sol))
Out[357]:
(:u, :u_analytic, :errors, :t, :k, :prob, :alg, :interp, :alg_choice, :dense, :tslocation, :destats, :retcode)
In [358]:
plot(sol, legend = false)
Out[358]:
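If the trajectories settle to an equilibrium in which every species has positive biomass, that point satisfies $r_i + \sum_j a_{ij} x_j^* = 0$, i.e. $x^* = -A^{-1} r$. The sketch below is an illustrative check, not part of the original notebook: it draws parameters the same way as the cell above, solves for $x^*$, and confirms that the derivative vanishes there. The linear solve can return negative entries, in which case no such coexistence equilibrium exists.

```julia
using Random

Random.seed!(2)
N = 10
r = rand(N)
a = -rand(N, N)
for i in 1:N
    a[i, i] = -2
end

# Candidate interior equilibrium: solve a * x = -r
x_star = -(a \ r)

# Per-species derivative x_i * (r_i + sum_j a_ij x_j) evaluated at x_star
dx_star = x_star .* (r .+ a * x_star)

println("all biomasses positive: ", all(x_star .> 0))
println("largest |dx/dt| at x*:  ", maximum(abs, dx_star))
```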

Callbacks¶

  • Ability to run code during simulations
    • Numerically stable
    • Efficient (no need to restart the integrator)

Callback structure¶

Callbacks are made up of two parts:

  1. condition(u, t, integrator): a function that detects when the callback should be applied, based on the state of the system
  2. affect!(integrator): a function that modifies the state of the system
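A minimal sketch of the two parts working together, on a toy decay problem rather than the ecosystem model: the state decays exponentially, and whenever it is found below 0.1 at the end of a step it is reset to 1.0. The names `decay!`, `below_threshold`, `reset!`, and `n_resets`, as well as the threshold and reset values, are assumptions made for illustration.

```julia
using DifferentialEquations

# Exponential decay dx/dt = -x, written in-place
function decay!(dx, x, p, t)
    dx[1] = -x[1]
end

# condition: true when the state has fallen below 0.1
below_threshold(x, t, integrator) = x[1] < 0.1

# affect!: reset the state to 1.0 and count the event
n_resets = Ref(0)
function reset!(integrator)
    integrator.u[1] = 1.0
    n_resets[] += 1
end

cb   = DiscreteCallback(below_threshold, reset!)
prob = ODEProblem(decay!, [1.0], (0.0, 10.0))
sol  = solve(prob, Tsit5(), callback = cb)

println("number of resets: ", n_resets[])
```

A `DiscreteCallback` evaluates its condition after each accepted step, so the state can dip somewhat below the threshold before the reset happens.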

Example 1: Extinction Detection¶

Detect when species go extinct and reintroduce them to the system

In [359]:
#check whether any biomass has dropped to zero or below
function condition(x,t,integrator)
  any(x .<= 0)
end

#reset the smallest biomass to 0.5
function affect!(integrator)
  integrator.u[findmin(integrator.u)[2]] = 0.5
end
Out[359]:
affect! (generic function with 1 method)
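The condition above is only checked at the end of each step, and a declining biomass approaches zero without necessarily reaching it. One possible variant, sketched here under assumed values, uses a `ContinuousCallback`, which locates the moment a continuous condition function crosses zero by root-finding. The extinction threshold, the two-species parameters, and the names `glv!`, `below_threshold`, `reintroduce!`, and `n_events` are all illustrative assumptions.

```julia
using DifferentialEquations

# Two-species system with the same form as dx!, using a tuple (r, a) as parameters
function glv!(dx, x, p, t)
    r, a = p
    dx .= x .* (r .+ a * x)
end

const THRESHOLD = 1e-3   # assumed extinction threshold
n_events = Ref(0)

# Crosses zero when the smallest biomass reaches the threshold
below_threshold(x, t, integrator) = minimum(x) - THRESHOLD

# Reintroduce the smallest species at biomass 0.5 and count the event
function reintroduce!(integrator)
    integrator.u[argmin(integrator.u)] = 0.5
    n_events[] += 1
end

r = [1.0, -0.5]               # species 2 declines on its own
a = [-2.0 0.0; 0.0 -2.0]
cb   = ContinuousCallback(below_threshold, reintroduce!)
prob = ODEProblem(glv!, [0.5, 0.5], (0.0, 50.0), (r, a))
sol  = solve(prob, Tsit5(), callback = cb)

println("reintroductions: ", n_events[])
```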
In [363]:
Random.seed!(2)

#generate params
N = 10
r = rand(N)
a = -rand(N,N)
[a[i,i] = -2 for i = 1:N]
  
p = param(N,r,a)

#define ODE problem
tspan = (0.0,1000.0)
x0 = rand(N)

prob = ODEProblem(dx!,x0,tspan,p)
cb = DiscreteCallback(condition,affect!)
sol = @time solve(prob, callback = cb);
  0.001199 seconds (4.24 k allocations: 418.352 KiB)
In [364]:
plot(sol, legend = false)
Out[364]:

Example 2: Altering parameters¶

Detect extinction and change parameters of extinct species

In [365]:
#check whether any biomass has dropped to zero or below
function condition(x,t,integrator)
  any(x .<= 0)
end

#add 0.1 to the smallest biomass and alter that species' parameters
function affect!(integrator)
    #index of the extinct species, found once so every update below hits the same species
    i = argmin(integrator.u)

    #new biomass
    integrator.u[i] += 0.1
    
    #altered parameters
    integrator.p.r[i] = rand()
    integrator.p.a[i,:] .= -rand(integrator.p.N) 
    integrator.p.a[i,i] = -2.0
end

#save the mean growth rate and the mean off-diagonal interaction in the community
function save_func(x, t, integrator)
    return(sum(integrator.p.r)/ length(integrator.p.r),
          (sum(integrator.p.a)+ 2*integrator.p.N) / (length(integrator.p.a) - integrator.p.N) )
end

saved_values = SavedValues(Float64, Tuple{Float64,Float64})
Out[365]:
SavedValues{tType=Float64, savevalType=Tuple{Float64, Float64}}
t:
Float64[]
saveval:
Tuple{Float64, Float64}[]
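The second value returned by `save_func` is the mean off-diagonal interaction. Each of the $N$ diagonal entries is fixed at $-2$, so adding $2N$ to `sum(a)` cancels the diagonal, and dividing by $N^2 - N$ averages over the remaining entries. The sketch below checks this shortcut against an explicit loop on a small hand-picked matrix; `shortcut_mean` and `offdiag_mean` are hypothetical helper names.

```julia
# Mean off-diagonal entry, computed as in save_func (assumes every diagonal entry is -2)
shortcut_mean(a) = (sum(a) + 2 * size(a, 1)) / (length(a) - size(a, 1))

# Mean off-diagonal entry, computed by explicitly skipping the diagonal
function offdiag_mean(a)
    n = size(a, 1)
    total = 0.0
    for i in 1:n, j in 1:n
        if i != j
            total += a[i, j]
        end
    end
    return total / (n^2 - n)
end

a = [-2.0 -0.1 -0.3;
     -0.4 -2.0 -0.2;
     -0.5 -0.6 -2.0]

println(shortcut_mean(a), " ", offdiag_mean(a))
```

For a single species there are no off-diagonal entries, so the expression becomes 0/0 and evaluates to `NaN` in floating point.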
In [368]:
Random.seed!(2)

#generate params
N = 10
r = rand(N)
a = -rand(N,N)
[a[i,i] = -2 for i = 1:N]
  
p = param(N,r,a)

new_sp = DiscreteCallback(condition,affect!)
saving_cb = SavingCallback(save_func, saved_values)
cb = CallbackSet(new_sp, saving_cb)

tspan = (0.0,6e3)
x0 = rand(N)

prob = ODEProblem(dx!,x0,tspan,p)

sol = @time solve(prob, callback = cb);
  0.003688 seconds (15.22 k allocations: 1.633 MiB)
In [369]:
plot(sol, legend = false)
Out[369]:
In [370]:
t = saved_values.t
r = map(x -> x[1], saved_values.saveval)
ā = map(x -> x[2], saved_values.saveval)

p1 = plot(t,r, title = "growth rates")
p2 = plot(t,ā, title = "interactions")
plot(p1,p2, size = (800, 400), legend = false, xlab = "time")
Out[370]:

Example 3: Changing integrator size¶

Add new species by resizing the integrator on the fly

In [371]:
#remove extinct species, add one new species, and update the parameters
function affect!(integrator)
    #get biomass array
    tmp = deepcopy(integrator.u)     
    #detect extant species
    ext = tmp .> eps()
    #get new system size
    integrator.p.N = (1 + sum(ext))
    #resize integrator
    resize!(integrator,integrator.p.N)

    #put in old biomass
    integrator.u[1:(integrator.p.N-1)] .= tmp[ext]
    #put in new biomass
    integrator.u[end] = 0.01
    
    #update params
    #growth rate
    deleteat!(integrator.p.r, .!ext)
    push!(integrator.p.r, rand())
    
    #interactions
    new_a = -rand(integrator.p.N, integrator.p.N) 
    new_a[integrator.p.N, integrator.p.N] = -2.0
    #add old data  
    new_a[1:(integrator.p.N-1) , 1:(integrator.p.N-1) ] .= integrator.p.a[ext, ext]

    integrator.p.a = new_a
end

function save_func(x, t, integrator)
    return(sum(integrator.p.r)/ length(integrator.p.r),
          (sum(integrator.p.a)+ 2*integrator.p.N) / (length(integrator.p.a) - integrator.p.N) )
end

saved_values = SavedValues(Float64, Tuple{Float64,Float64})
Out[371]:
SavedValues{tType=Float64, savevalType=Tuple{Float64, Float64}}
t:
Float64[]
saveval:
Tuple{Float64, Float64}[]
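The bookkeeping inside this `affect!` can be tried out without an integrator. The sketch below applies the same steps (keep extant species, append one newcomer, carry over the surviving block of the interaction matrix) to plain arrays. `prune_and_add` is a hypothetical helper, and the example values are made up for illustration.

```julia
using Random

# Drop extinct species and append one new species (hypothetical helper,
# mirroring the bookkeeping in affect! without touching an integrator)
function prune_and_add(x, r, a; new_biomass = 0.01)
    ext = x .> eps()            # extant species
    n   = sum(ext) + 1          # survivors plus one newcomer

    new_x = vcat(x[ext], new_biomass)
    new_r = vcat(r[ext], rand())

    new_a = -rand(n, n)                    # fresh random interactions
    new_a[n, n] = -2.0                     # self-regulation of the newcomer
    new_a[1:n-1, 1:n-1] .= a[ext, ext]     # keep interactions among survivors
    return new_x, new_r, new_a
end

Random.seed!(1)
x = [0.3, 0.0, 0.2]              # species 2 is extinct
r = [0.5, 0.7, 0.9]
a = [-2.0 -0.1 -0.3;
     -0.4 -2.0 -0.2;
     -0.5 -0.6 -2.0]

new_x, new_r, new_a = prune_and_add(x, r, a)
println(length(new_x), " species; size(new_a) = ", size(new_a))
```

Indexing with the Boolean mask `ext` on both axes, `a[ext, ext]`, selects exactly the rows and columns of the survivors, which is why their pairwise interactions are preserved.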
In [374]:
Random.seed!(2)

#redefine parameters structure to be changeable
mutable struct param_mutable
    N::Int
    r::Vector{Float64}
    a::Array{Float64,2}
end

#generate params
N = 1
r = rand(N)
a = -rand(N,N)
[a[i,i] = -2 for i = 1:N]
  
p = param_mutable(N,r,a)

new_sp = PeriodicCallback(affect!, 100)
saving_cb = SavingCallback(save_func, saved_values)

cb = CallbackSet(new_sp, saving_cb)

tspan = (0.0,2e5)
x0 = rand(N)
prob = ODEProblem(dx!,x0,tspan,p)

sol = @time solve(prob, callback = cb);
  1.593246 seconds (1.16 M allocations: 227.388 MiB)
In [375]:
plot(sol.t, length.(sol.u), xlab = "time", ylab = "Number of Species", legend = false)
Out[375]:
In [376]:
t = saved_values.t
r = map(x -> x[1], saved_values.saveval)
a = map(x -> x[2], saved_values.saveval)

p1 = plot(t,r, title = "growth rate")
p2 = plot(t,a, title = "interactions")

plot(p1,p2, legend = false, size = (800,400))
Out[376]:
In [397]:
function simulate_with_extinctions(p)
    new_sp = PeriodicCallback(affect!, 100)
    saving_cb = SavingCallback(save_func, saved_values)

    cb = CallbackSet(new_sp, saving_cb)

    tspan = (0.0,6e3)
    x0 = rand(p.N)
    prob = ODEProblem(dx!,x0,tspan,deepcopy(p))

    solve(prob, callback = cb);
end

function affect_noext!(integrator)
    #get new system size
    integrator.p.N += 1 
    #resize integrator
    resize!(integrator,integrator.p.N)

    #put in new biomass
    integrator.u[end] = 0.01
    
    #update params
    #growth rate
    push!(integrator.p.r, rand())
    
    #interactions
    new_a = -rand(integrator.p.N, integrator.p.N) 
    new_a[integrator.p.N, integrator.p.N] = -2.0
    #add old data  
    new_a[1:(integrator.p.N-1) , 1:(integrator.p.N-1) ] .= integrator.p.a

    integrator.p.a = new_a
end

function simulate_no_extinctions(p)
    new_sp = PeriodicCallback(affect_noext!, 100)
    saving_cb = SavingCallback(save_func, saved_values)

    cb = CallbackSet(new_sp, saving_cb)

    tspan = (0.0,6e3)
    x0 = rand(p.N)
    prob = ODEProblem(dx!,x0,tspan,deepcopy(p))

    solve(prob, callback = cb);
end
Out[397]:
simulate_no_extinctions (generic function with 1 method)

Expanding System Comparison¶

In [401]:
#generate params
N = 1
r = rand(N)
a = -rand(N,N)
[a[i,i] = -2 for i = 1:N]
  
p = param_mutable(N,r,a)

println("with extinctions:")
@time simulate_with_extinctions(p);
println("without extinctions:")
@time simulate_no_extinctions(p);
with extinctions:
  0.011131 seconds (32.65 k allocations: 3.788 MiB)
without extinctions:
  0.807227 seconds (39.37 k allocations: 6.585 MiB)